import pandas as pd
from django.db import IntegrityError, transaction
from django.db.models import Count, Avg
from django.http import JsonResponse
from django.shortcuts import render
from django.views.decorators.csrf import csrf_exempt
from rest_framework import status
from rest_framework import generics, viewsets
from rest_framework.decorators import api_view
from rest_framework.pagination import PageNumberPagination
from rest_framework.parsers import MultiPartParser, FormParser
from rest_framework.response import Response
from rest_framework.views import APIView

from .models import EnvironmentalData, Company, CompanyData, ShareholderData, FollowedCompany
from .models import Securities
from .serializers import EnvironmentalDataSerializer, CompanySerializer, CompanyDataSerializer, \
    ShareholderDataSerializer, SecuritiesSerializer, CompanyGroupBySerializer, FollowedCompanySerializer

# Column-header mapping used by upload_excel. For every recognised sheet name it
# maps the Chinese column header found in that sheet to the model field the
# value is stored in. Field names follow each target model's own naming (e.g.
# the environmental sheet uses 'serial' / 'report_period' where other sheets use
# 'serial_number' / 'reporting_period'), so the per-sheet dicts are not
# interchangeable.
# Units appearing in headers: 万元 = 10,000 CNY, 万股 = 10,000 shares,
# 吨CO2e = tonnes of CO2-equivalent, (%) = percentage.
MODEL_FIELD_MAPPING = {
    # 'ESG rating' sheet; upload_excel imports it into Company.
    'ESG评级': {
        '序号': 'serial_number',
        '公司名称': 'name',
        '评级日期': 'rating_date',
        'ESG级别': 'esg_rating',
        'E维度级别': 'e_dimension_rating',
        'S维度级别': 's_dimension_rating',
        'G维度级别': 'g_dimension_rating',
        '所属省': 'province',
        '所属市': 'city',
        '所属县': 'district',
        '行业': 'industry',
        '成立日期': 'establishment_date',
        '注册资本': 'registered_capital',
    },
    # 'Environmental protection' sheet; upload_excel imports it into EnvironmentalData.
    '环境保护': {
        '序号': 'serial',
        '股票简称': 'stock_name',
        '股票代码': 'stock_code',
        '报告期': 'report_period',
        '碳排放总量(范围一、二)(吨CO2e)': 'carbon_emission_total',
        '单位营收碳排放量(范围一、二)(吨CO2e/百万元)': 'carbon_emission_per_revenue',
        '用水量': 'water_usage',
        '单位营收用水量': 'water_usage_per_revenue',
        '废弃物总量': 'waste_total',
        '单位营收废弃物总量': 'waste_per_revenue',
        '综合耗能': 'energy_consumption',
        '单位营收综合耗能': 'energy_consumption_per_revenue',
        '环境保护投入': 'environmental_protection_investment',
    },
    # 'Social responsibility' sheet; upload_excel imports it into CompanyData.
    '社会责任': {
        '序号': 'serial_number',
        '股票简称': 'stock_abbreviation',
        '股票代码': 'stock_code',
        '报告期': 'reporting_period',
        '员工人数': 'employee_count',
        '研发费用(万元)': 'r_and_d_expenses',
        '研发费用增长率(%)': 'r_and_d_expenses_growth_rate',
        '研发费用占营收比(%)': 'r_and_d_expenses_to_revenue_ratio',
        '公益投入金额': 'public_welfare_investment',
        '扶贫投入金额': 'poverty_alleviation_investment',
        '员工薪酬总额': 'total_employee_compensation',
        '员工人均薪酬': 'average_employee_compensation',
    },
    # 'Corporate governance' sheet; upload_excel imports it into ShareholderData.
    '公司治理': {
        '序号': 'serial_number',
        '股票简称': 'stock_abbreviation',
        '股票代码': 'stock_code',
        '报告期': 'reporting_period',
        '实际控制人名称': 'actual_controller_name',
        '实际控制人性质': 'actual_controller_nature',
        '股东人数(人)': 'shareholder_count',
        '第一大股东持股数量(万股)': 'largest_shareholder_quantity',
        '第一大股东持股比例(%)': 'largest_shareholder_percentage',
        '前十大股东持股数量合计(万股)': 'top_ten_shareholders_quantity',
        '前十大股东持股占比(%)': 'top_ten_shareholders_percentage',
        '总股份(万股)': 'total_shares',
        '现金分红率': 'cash_dividend_rate',
        '每股现金分红金额': 'per_share_cash_dividend_amount',
        '董事会人数': 'board_members_count',
        '女性董事人数': 'female_directors_count',
        '女性董事比例': 'female_directors_percentage',
        '现金分红总额': 'total_cash_dividends',
    },
}

def _excel_cell_to_python(value):
    """Return ``None`` for an empty spreadsheet cell, otherwise ``value`` unchanged.

    pandas represents blank cells as ``NaN``/``NaT``; handing those straight to
    the ORM stores NaN (or the string ``'nan'`` in text fields) instead of NULL.
    """
    if pd.api.types.is_scalar(value) and pd.isna(value):
        return None
    return value


@csrf_exempt
def upload_excel(request):
    """Import the recognised sheets of an uploaded Excel workbook.

    Expects a POST carrying the workbook in the ``excel_file`` upload field.
    Every sheet whose name appears in the sheet -> model table below is imported
    row by row, translating column headers with ``MODEL_FIELD_MAPPING``. Sheets
    that are not in the table and headers that are not mapped are ignored.
    Rows rejected by the database with ``IntegrityError`` (e.g. duplicates) are
    skipped and counted in the ``skipped`` field of the success response.

    The response is always a ``JsonResponse`` with the default HTTP 200 status
    and a ``status`` of ``'success'`` or ``'error'``.
    """
    # Only POST uploads are accepted.
    if request.method != 'POST':
        return JsonResponse({
            'status': 'error',
            'message': '非法请求方法',
        })

    try:
        # A missing 'excel_file' field raises KeyError, reported by the generic
        # error response below.
        uploaded_file = request.FILES['excel_file']

        # sheet_name=None loads every sheet as {sheet name: DataFrame}.
        excel_data = pd.read_excel(uploaded_file, sheet_name=None)

        # Chinese sheet name -> model class the sheet is imported into.
        model_mapping = {
            '环境保护': EnvironmentalData,
            'ESG评级': Company,
            '社会责任': CompanyData,
            '公司治理': ShareholderData,
        }

        skipped = 0
        for sheet_name, df in excel_data.items():
            model = model_mapping.get(sheet_name)
            if model is None:
                continue  # unrecognised sheet

            field_map = MODEL_FIELD_MAPPING[sheet_name]
            # Resolve headers once per sheet. Unmapped headers (e.g. pandas'
            # "Unnamed: N" or a notes column) are ignored rather than aborting
            # the upload halfway with a KeyError; get_field() still fails loudly
            # if the mapping names a field the model does not have.
            columns = {
                col: model._meta.get_field(field_map[col]).name
                for col in df.columns
                if col in field_map
            }

            for _, row in df.iterrows():
                model_data = {
                    field: _excel_cell_to_python(row[col])
                    for col, field in columns.items()
                }
                try:
                    # Savepoint per row so a rejected row does not break an
                    # enclosing transaction (e.g. with ATOMIC_REQUESTS enabled).
                    with transaction.atomic():
                        model.objects.create(**model_data)
                except IntegrityError:
                    # Duplicate / constraint-violating row: skip it, but count it.
                    skipped += 1

        return JsonResponse({
            'status': 'success',
            'message': '上传成功',
            'skipped': skipped,
        })

    except Exception as e:
        # Top-level boundary: report any parsing/saving failure to the client.
        return JsonResponse({
            'status': 'error',
            'message': f'文件上传失败: {str(e)}',
        })

class EnvironmentalDataViewSet(viewsets.ModelViewSet):
    """DRF model viewset exposing every ``EnvironmentalData`` record."""

    serializer_class = EnvironmentalDataSerializer
    queryset = EnvironmentalData.objects.all()

class CompanyViewSet(viewsets.ModelViewSet):
    """DRF model viewset exposing every ``Company`` record."""

    serializer_class = CompanySerializer
    queryset = Company.objects.all()

class CompanyDataViewSet(viewsets.ModelViewSet):
    """DRF model viewset exposing every ``CompanyData`` record."""

    serializer_class = CompanyDataSerializer
    queryset = CompanyData.objects.all()

class ShareholderDataViewSet(viewsets.ModelViewSet):
    """DRF model viewset exposing every ``ShareholderData`` record."""

    serializer_class = ShareholderDataSerializer
    queryset = ShareholderData.objects.all()


class EnvironmentalDataList(generics.ListCreateAPIView):
    """List all ``EnvironmentalData`` records or create a new one."""

    serializer_class = EnvironmentalDataSerializer
    queryset = EnvironmentalData.objects.all()

class EnvironmentalDataDetail(generics.RetrieveUpdateDestroyAPIView):
    """Retrieve, update or delete a single ``EnvironmentalData`` record."""

    serializer_class = EnvironmentalDataSerializer
    queryset = EnvironmentalData.objects.all()

class CompanyList(generics.ListCreateAPIView):
    """List all ``Company`` records or create a new one."""

    serializer_class = CompanySerializer
    queryset = Company.objects.all()

class CompanyDetail(generics.RetrieveUpdateDestroyAPIView):
    """Retrieve, update or delete a single ``Company`` record."""

    serializer_class = CompanySerializer
    queryset = Company.objects.all()

class CompanyDataList(generics.ListCreateAPIView):
    """List all ``CompanyData`` records or create a new one."""

    serializer_class = CompanyDataSerializer
    queryset = CompanyData.objects.all()

class CompanyDataDetail(generics.RetrieveUpdateDestroyAPIView):
    """Retrieve, update or delete a single ``CompanyData`` record."""

    serializer_class = CompanyDataSerializer
    queryset = CompanyData.objects.all()

class ShareholderDataList(generics.ListCreateAPIView):
    """List all ``ShareholderData`` records or create a new one."""

    serializer_class = ShareholderDataSerializer
    queryset = ShareholderData.objects.all()

class ShareholderDataDetail(generics.RetrieveUpdateDestroyAPIView):
    """Retrieve, update or delete a single ``ShareholderData`` record."""

    serializer_class = ShareholderDataSerializer
    queryset = ShareholderData.objects.all()

class CompanyCountByESGView(APIView):
    """Return ``[{'esg_rating': ..., 'count': n}, ...]`` for all ``Company`` rows."""

    def get(self, request):
        # Count rows ('id', as CompanyByProvinceView does) rather than the rating
        # column: COUNT(esg_rating) skips NULLs, so the NULL-rating group would
        # report 0. The explicit order_by gives a stable output order and keeps
        # any default model ordering out of GROUP BY (relevant on Django < 3.1).
        esg_counts = (
            Company.objects.values('esg_rating')
            .annotate(count=Count('id'))
            .order_by('esg_rating')
        )
        return Response(esg_counts)
class OverallRatingCountView(APIView):
    """Return ``[{'overall_rating': ..., 'count': n}, ...]`` over all ``Securities``."""

    def get(self, request):
        # Count rows (pk) rather than the rating column: COUNT(overall_rating)
        # skips NULLs, so the NULL-rating group would report 0. order_by gives a
        # stable output order and keeps default model ordering out of GROUP BY.
        overall_rating_counts = (
            Securities.objects.values('overall_rating')
            .annotate(count=Count('pk'))
            .order_by('overall_rating')
        )
        return Response(overall_rating_counts)
class SRatingCountView(APIView):
    """Return ``[{'s_rating': ..., 'count': n}, ...]`` over all ``Securities``."""

    def get(self, request):
        # Count rows (pk) rather than the rating column: COUNT(s_rating) skips
        # NULLs, so the NULL-rating group would report 0. order_by gives a stable
        # output order and keeps default model ordering out of GROUP BY.
        s_rating_counts = (
            Securities.objects.values('s_rating')
            .annotate(count=Count('pk'))
            .order_by('s_rating')
        )
        return Response(s_rating_counts)
class ERatingCountView(APIView):
    """Return ``[{'e_rating': ..., 'count': n}, ...]`` over all ``Securities``."""

    def get(self, request):
        # Count rows (pk) rather than the rating column: COUNT(e_rating) skips
        # NULLs, so the NULL-rating group would report 0. order_by gives a stable
        # output order and keeps default model ordering out of GROUP BY.
        e_rating_counts = (
            Securities.objects.values('e_rating')
            .annotate(count=Count('pk'))
            .order_by('e_rating')
        )
        return Response(e_rating_counts)

class GRatingCountView(APIView):
    """Return ``[{'g_rating': ..., 'count': n}, ...]`` over all ``Securities``."""

    def get(self, request):
        # Count rows (pk) rather than the rating column: COUNT(g_rating) skips
        # NULLs, so the NULL-rating group would report 0. order_by gives a stable
        # output order and keeps default model ordering out of GROUP BY.
        g_rating_counts = (
            Securities.objects.values('g_rating')
            .annotate(count=Count('pk'))
            .order_by('g_rating')
        )
        return Response(g_rating_counts)
class OverallScoreByYearView(APIView):
    """Return ``[{'year': ..., 'overall_score_avg': x}, ...]`` in ascending year order."""

    def get(self, request):
        # SQL GROUP BY guarantees no output order; sort by year explicitly so the
        # series is chronological (this also keeps default model ordering out of
        # the GROUP BY clause).
        overall_score_averages = (
            Securities.objects.values('year')
            .annotate(overall_score_avg=Avg('overall_score'))
            .order_by('year')
        )

        return Response(overall_score_averages)
class TopOverallScoresView(APIView):
    """Return the ten ``Securities`` with the highest ``overall_score``.

    Rows without a score are excluded: some databases (e.g. PostgreSQL) sort
    NULLs first in a descending ORDER BY, which would push real scores out of
    the top ten. ``pk`` breaks ties so the result is stable between requests.
    """

    def get(self, request):
        top_scores = (
            Securities.objects.exclude(overall_score__isnull=True)
            .order_by('-overall_score', 'pk')[:10]
        )
        serializer = SecuritiesSerializer(top_scores, many=True)
        return Response(serializer.data)
class CompanyByProvinceView(APIView):
    """Return the number of ``Company`` rows per ``province``, ordered by province."""

    def get(self, request):
        # Explicit order_by: stable output order, and default model ordering is
        # kept out of GROUP BY (on Django < 3.1 it would split the groups).
        company_counts = (
            Company.objects.values('province')
            .annotate(count=Count('id'))
            .order_by('province')
        )
        serializer = CompanyGroupBySerializer(company_counts, many=True)
        return Response(serializer.data)
# NOTE: CompanyViewSet is defined once, next to the other ModelViewSets earlier
# in this module. To paginate company listings, set ``pagination_class`` on
# that definition (e.g. a PageNumberPagination subclass with page_size=10).

def average_e_score_by_year(request):
    """Return ``{year: average e_score}`` for all ``Securities`` as JSON.

    Keys are emitted in ascending year order (JSON turns them into strings).
    """
    # Group by year and average the E score; order_by makes the dict (and the
    # JSON object) chronological instead of whatever order the DB returns.
    queryset = (
        Securities.objects.values('year')
        .annotate(avg_e_score=Avg('e_score'))
        .order_by('year')
    )
    result = {entry['year']: entry['avg_e_score'] for entry in queryset}
    return JsonResponse(result)

def average_s_score_by_year(request):
    """Return ``{year: average s_score}`` for all ``Securities`` as JSON.

    Keys are emitted in ascending year order (JSON turns them into strings).
    """
    queryset = (
        Securities.objects.values('year')
        .annotate(avg_s_score=Avg('s_score'))
        .order_by('year')
    )
    result = {entry['year']: entry['avg_s_score'] for entry in queryset}
    return JsonResponse(result)

def average_g_score_by_year(request):
    """Return ``{year: average g_score}`` for all ``Securities`` as JSON.

    Keys are emitted in ascending year order (JSON turns them into strings).
    """
    queryset = (
        Securities.objects.values('year')
        .annotate(avg_g_score=Avg('g_score'))
        .order_by('year')
    )
    result = {entry['year']: entry['avg_g_score'] for entry in queryset}
    return JsonResponse(result)

class OverallRatingTopThree(generics.ListAPIView):
    """List the three ``Securities`` with the highest ``overall_score``.

    NULL scores are excluded (some databases sort them first when descending);
    ``pk`` breaks ties so the result is stable.
    """
    queryset = Securities.objects.exclude(overall_score__isnull=True).order_by('-overall_score', 'pk')[:3]
    serializer_class = SecuritiesSerializer

class ERatingTopThree(generics.ListAPIView):
    """List the three ``Securities`` with the highest ``e_score``.

    NULL scores are excluded (some databases sort them first when descending);
    ``pk`` breaks ties so the result is stable.
    """
    queryset = Securities.objects.exclude(e_score__isnull=True).order_by('-e_score', 'pk')[:3]
    serializer_class = SecuritiesSerializer

class SRatingTopThree(generics.ListAPIView):
    """List the three ``Securities`` with the highest ``s_score``.

    NULL scores are excluded (some databases sort them first when descending);
    ``pk`` breaks ties so the result is stable.
    """
    queryset = Securities.objects.exclude(s_score__isnull=True).order_by('-s_score', 'pk')[:3]
    serializer_class = SecuritiesSerializer

class GRatingTopThree(generics.ListAPIView):
    """List the three ``Securities`` with the highest ``g_score``.

    NULL scores are excluded (some databases sort them first when descending);
    ``pk`` breaks ties so the result is stable.
    """
    queryset = Securities.objects.exclude(g_score__isnull=True).order_by('-g_score', 'pk')[:3]
    serializer_class = SecuritiesSerializer

class OverallRatingTopSix(generics.ListAPIView):
    """List the six ``Securities`` with the highest ``overall_score``.

    NULL scores are excluded (some databases sort them first when descending);
    ``pk`` breaks ties so the result is stable.
    """
    queryset = Securities.objects.exclude(overall_score__isnull=True).order_by('-overall_score', 'pk')[:6]
    serializer_class = SecuritiesSerializer

class ERatingTopSix(generics.ListAPIView):
    """List the six ``Securities`` with the highest ``e_score``.

    NULL scores are excluded (some databases sort them first when descending);
    ``pk`` breaks ties so the result is stable.
    """
    queryset = Securities.objects.exclude(e_score__isnull=True).order_by('-e_score', 'pk')[:6]
    serializer_class = SecuritiesSerializer

class SRatingTopSix(generics.ListAPIView):
    """List the six ``Securities`` with the highest ``s_score``.

    NULL scores are excluded (some databases sort them first when descending);
    ``pk`` breaks ties so the result is stable.
    """
    queryset = Securities.objects.exclude(s_score__isnull=True).order_by('-s_score', 'pk')[:6]
    serializer_class = SecuritiesSerializer

class GRatingTopSix(generics.ListAPIView):
    """List the six ``Securities`` with the highest ``g_score``.

    NULL scores are excluded (some databases sort them first when descending);
    ``pk`` breaks ties so the result is stable.
    """
    queryset = Securities.objects.exclude(g_score__isnull=True).order_by('-g_score', 'pk')[:6]
    serializer_class = SecuritiesSerializer
class FollowedCompanyViewSet(viewsets.ModelViewSet):
    """DRF model viewset exposing every ``FollowedCompany`` record."""

    serializer_class = FollowedCompanySerializer
    queryset = FollowedCompany.objects.all()

class UploadSecuritiesExcel(APIView):
    """Bulk-import ``Securities`` rows from an uploaded ``.xlsx`` workbook.

    Expects a multipart POST whose ``file`` part holds the workbook; only the
    first sheet is read (``pd.read_excel`` default). Headers listed in
    ``COLUMN_MAPPING`` are renamed to model field names. Columns that are not
    concrete ``Securities`` fields are dropped, and blank cells become NULL.

    Responses: 400 for a missing/non-.xlsx upload, 201 on success, 500 with the
    exception text on any parsing or saving failure.
    """
    parser_classes = [MultiPartParser, FormParser]

    # Excel column header -> Securities field name.
    COLUMN_MAPPING = {
        '证券代码': 'security_code',
        '证券简称': 'security_name',
        '综合评级': 'overall_rating',
        '年度': 'year',
        '综合得分': 'overall_score',
        'E评级': 'e_rating',
        'E得分': 'e_score',
        'S评级': 's_rating',
        'S得分': 's_score',
        'G评级': 'g_rating',
        'G得分': 'g_score',
    }

    @staticmethod
    def _clean_cell(value):
        """Map pandas' blank-cell markers (NaN/NaT) to None; keep other values."""
        if pd.api.types.is_scalar(value) and pd.isna(value):
            return None
        return value

    def post(self, request, *args, **kwargs):
        file = request.data.get('file')

        # Case-insensitive so that e.g. "DATA.XLSX" is accepted as well.
        if file is None or not file.name.lower().endswith('.xlsx'):
            return Response({'error': '请上传格式 (.xlsx).'}, status=status.HTTP_400_BAD_REQUEST)

        try:
            df = pd.read_excel(file).rename(columns=self.COLUMN_MAPPING)

            # Securities(**kwargs) raises TypeError for unknown keywords, so a
            # single extra column (index, notes, "Unnamed: N") used to fail the
            # whole upload. Keep only columns naming a concrete model field.
            known_fields = {
                name
                for field in Securities._meta.concrete_fields
                for name in (field.name, field.attname)
            }
            df = df[[col for col in df.columns if col in known_fields]]

            # Blank cells arrive as NaN/NaT; store them as NULL rather than NaN
            # or the string 'nan'.
            securities = [
                Securities(**{key: self._clean_cell(value) for key, value in record.items()})
                for record in df.to_dict(orient='records')
            ]
            Securities.objects.bulk_create(securities)

            return Response({'success': '上传文件成功'}, status=status.HTTP_201_CREATED)
        except Exception as e:
            return Response({'error': f'上传文件错误: {str(e)}'}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)

@api_view(['GET'])
def user_followed_companies(request, user_id):
    """Return, as a JSON list, the FollowedCompany rows whose ``user_id`` matches."""
    payload = FollowedCompanySerializer(
        FollowedCompany.objects.filter(user_id=user_id),
        many=True,
    ).data
    # The serialized payload is a list, which JsonResponse only accepts with safe=False.
    return JsonResponse(payload, safe=False)